import json
import time
import pandas as pd
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import accuracy_score, classification_report
import torch
import os
import re

# Checkpoint to evaluate (placeholder path - set before running).
model_path = "xxx"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# nn.Module.eval() returns the module itself, so it can be chained directly
# onto the load; inference mode disables dropout for deterministic scoring.
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto").eval()

# Evaluation set; the code below reads its "Text" and "Label" columns.
data = pd.read_csv("xxx.csv")

# Integer class id -> human-readable label written to the JSON/CSV outputs.
sarcasm_mapping = {
    0: "not sarcasm",
    1: "sarcasm",
}
# Inverse lookup: human-readable label -> integer class id.
reverse_sarcasm_mapping = dict(zip(sarcasm_mapping.values(), sarcasm_mapping.keys()))

def build_prompt(text: str) -> str:
    """Return the step-by-step classification prompt for a single text.

    The prompt asks the model to reason and then emit ``Sarcasm: <label>``,
    lists the 0/1 label definitions, quotes *text* verbatim, and pre-fills the
    opening of the reasoning so generation continues from there.
    """
    label_definitions = (
        "1=sarcasm: contains features like surface praise with underlying criticism, "
        "contextual incongruity, exaggerated contrast, etc. | "
        "0=not sarcasm"
    )
    sections = [
        "Perform step-by-step reasoning to identify the sarcasm in the given text. "
        "After your reasoning, output the final sarcasm label in the exact format: 'Sarcasm: <label>'.",
        f"Sarcasm Definitions:\n{label_definitions}",
        f"Text: \"{text}\"",
        "Reasoning: Let's think step by step. First, I need to analyze the context and linguistic cues...",
    ]
    # Sections are separated by one blank line; the last one has no trailing newline.
    return "\n\n".join(sections)

def save_results(results, output_file):
    """Write *results* to *output_file* as pretty-printed UTF-8 JSON, atomically.

    The data is first written to ``<output_file>.temp``. That file is flushed
    to disk and then moved over the destination with :func:`os.replace`.
    A reader, including a resumed run, therefore always sees a complete
    checkpoint: either the previous one or the new one.

    Args:
        results: JSON-serialisable object (here, the list of per-sample dicts).
        output_file: Destination path of the JSON checkpoint.
    """
    temp_file = output_file + '.temp'
    with open(temp_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
        # Make sure the bytes are on disk before the rename publishes them.
        f.flush()
        os.fsync(f.fileno())
    # os.replace overwrites an existing destination in one step on POSIX and
    # Windows. The previous remove-then-rename left a window in which a crash
    # would leave no checkpoint at all.
    os.replace(temp_file, output_file)

# Inputs and gold labels, aligned by row position with `data`.
texts = data["Text"].astype(str).tolist()
true_labels = data["Label"].astype(int).tolist()

results = []
output_file = "xxx.json"

# Resume from an existing checkpoint, if any.
if os.path.exists(output_file):
    with open(output_file, 'r', encoding='utf-8') as f:
        saved_results = json.load(f)
    # Failed generations are stored with predicted_label == "error".
    # Drop them so they are retried in this run; otherwise a transient
    # failure would be frozen into the final metrics.
    results = [item for item in saved_results if item.get("predicted_label") != "error"]

# Indices already holding a usable prediction; the main loop skips these.
processed_indices = {item['index'] for item in results}

# Counters for samples processed in *this* session only (resumed samples are
# not included); kept for backward compatibility with the report code.
total_correct = 0
total_tokens = 0

# Final-answer pattern: "Sarcasm:" followed by optional markdown/quote
# wrappers (e.g. "Sarcasm: **1**"), capturing a 0/1 digit or the first word.
_FINAL_LABEL_RE = re.compile(r"Sarcasm:\s*[*'\"`]*\s*([01]|\w+)", re.IGNORECASE)
# Fallback when no "Sarcasm:" line is present: a standalone 0 or 1.
_BARE_DIGIT_RE = re.compile(r"\b[01]\b")


def _generate_reply(prompt):
    """Generate a continuation of *prompt* with do_sample=False.

    Returns:
        (reply, token_count): *reply* is the decoded text of the newly
        generated tokens only (prompt excluded, special tokens skipped,
        stripped). *token_count* is the number of generated token ids.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=16384,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Cut the prompt off by token position rather than string-replacing the
    # decoded prompt. Decoding the encoded prompt need not reproduce it
    # exactly, and any leftover prompt text ("1=sarcasm: contains ...")
    # would be picked up by the label regex.
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_length:]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return reply, int(new_tokens.shape[-1])


def _parse_sarcasm_label(reply):
    """Map the model's reply to 0 (not sarcasm) or 1 (sarcasm); defaults to 0."""
    answers = _FINAL_LABEL_RE.findall(reply)
    if answers:
        # Take the last occurrence: the reasoning may mention "Sarcasm:"
        # before the final answer line.
        answer = answers[-1].lower()
        if answer in ("0", "1"):
            return int(answer)
        # Word answers: "sarcasm"/"sarcastic" -> 1; "not ..." or anything else -> 0.
        return 1 if answer.startswith("sarcas") else 0
    digits = _BARE_DIGIT_RE.findall(reply)
    return int(digits[-1]) if digits else 0


for index, text in enumerate(tqdm(texts)):
    if index in processed_indices:
        continue

    prompt = build_prompt(text)
    true_label = true_labels[index]

    try:
        reply, token_count = _generate_reply(prompt)
    except Exception as e:
        # Record a per-sample failure instead of aborting the run, but make
        # it visible (tqdm.write keeps the progress bar intact).
        tqdm.write(f"[error] sample {index}: {type(e).__name__}: {e}")
        results.append({
            "index": index,
            "text": text,
            "true_label": sarcasm_mapping.get(true_label, "unknown"),
            "predicted_label": "error",
            "is_correct": False,
            "full_response": str(e),
            "token_count": 0,
            "prompt": prompt
        })
        time.sleep(2)
        continue

    pred_label = _parse_sarcasm_label(reply)
    is_correct = (pred_label == true_label)
    if is_correct:
        total_correct += 1
    total_tokens += token_count

    results.append({
        "index": index,
        "text": text,
        "true_label": sarcasm_mapping.get(true_label, "unknown"),
        "predicted_label": sarcasm_mapping.get(pred_label, "unknown"),
        "is_correct": is_correct,
        "full_response": reply,
        "token_count": token_count,
        "prompt": prompt
    })

    # Periodic checkpoint so a crash loses at most a few samples.
    if index % 5 == 0:
        save_results(results, output_file)

save_results(results, output_file)

# Derive every metric from `results`, which also contains samples finished in
# an earlier (resumed) run; session-only counters would under-count them.
# Entries are looked up by their stored "index" so predictions stay aligned
# with texts/true_labels regardless of list order or stale extra entries.
results_by_index = {item['index']: item for item in results}
ordered_results = [results_by_index[i] for i in range(len(texts))]

num_samples = len(ordered_results)
accuracy = sum(item['is_correct'] for item in ordered_results) / num_samples if num_samples else 0
avg_tokens = sum(item['token_count'] for item in ordered_results) / num_samples if num_samples else 0
num_errors = sum(item['predicted_label'] == "error" for item in ordered_results)

print(f"Accuracy: {accuracy:.4f}")
print(f"Average tokens per sample: {avg_tokens:.1f}")
if num_errors:
    print(f"Generation errors: {num_errors} (counted as incorrect in accuracy; mapped to label 0 below)")

# Labels outside the mapping ("error", "unknown") fall back to class 0.
pred_labels = [reverse_sarcasm_mapping.get(item['predicted_label'], 0) for item in ordered_results]
print(classification_report(true_labels, pred_labels, digits=4))

df = pd.DataFrame({
    "text": texts,
    "true_label": [sarcasm_mapping.get(label, "unknown") for label in true_labels],
    "pred_label": [item['predicted_label'] for item in ordered_results],
    "is_correct": [item['is_correct'] for item in ordered_results],
    "token_count": [item['token_count'] for item in ordered_results]
})
df.to_csv("full_predictions_sarcasm_ppo_model.csv", index=False, encoding="utf-8-sig")